Variable elimination #24

Open

francescoTheSantis wants to merge 8 commits into pyc-team:dev from francescoTheSantis:variable_elimanation
Conversation

@francescoTheSantis (Collaborator) commented on Apr 9, 2026

This PR implements the following:

  • the variable elimination algorithm (a sketch of the idea follows this list)
  • a negative log-likelihood loss
  • a factor (distinct from the parametric factor), used to handle the intermediate factors obtained during inference
  • (minor) allow setting a prior distribution when generating a synthetic dataset with categorical_toy_dag.py
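
For readers new to the algorithm: variable elimination sums out one variable at a time by multiplying the factors that mention it and marginalising it away. A minimal log-space sketch (the `Factor`, `product`, and `eliminate` names here are illustrative, not this PR's actual API):

```python
import torch

class Factor:
    """A table over discrete variables, stored in log-space."""
    def __init__(self, log_values: torch.Tensor, variables: list):
        assert log_values.ndim == len(variables)
        self.log_values = log_values
        self.variables = variables

def product(f: Factor, g: Factor) -> Factor:
    """Factor product: add log-tables over the union of the variables."""
    variables = f.variables + [v for v in g.variables if v not in f.variables]

    def align(fac: Factor) -> torch.Tensor:
        # Permute the factor's axes into the joint variable order, then
        # insert size-1 axes for variables it does not mention, so plain
        # broadcasting implements the product.
        own = [v for v in variables if v in fac.variables]
        t = fac.log_values.permute(*[fac.variables.index(v) for v in own])
        return t.reshape([t.shape[own.index(v)] if v in own else 1 for v in variables])

    return Factor(align(f) + align(g), variables)

def eliminate(factors: list, var: str) -> list:
    """Sum out `var` (assumes it appears in at least one factor)."""
    touching = [f for f in factors if var in f.variables]
    rest = [f for f in factors if var not in f.variables]
    joint = touching[0]
    for f in touching[1:]:
        joint = product(joint, f)
    axis = joint.variables.index(var)
    marginal = torch.logsumexp(joint.log_values, dim=axis)
    return rest + [Factor(marginal, [v for v in joint.variables if v != var])]
```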

@gdefe changed the title from "Variable elimanation" to "Variable elimination" on Apr 10, 2026

Dict with 'input' and 'target' for loss computation.
"""
if isinstance(forward_out, dict):
    return {'input': forward_out['log_joint'], 'target': target}

we must find a way not to hard-code 'log_joint' here
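
One possible direction (a hypothetical sketch, not this PR's code): inject the key at construction time so that 'log_joint' here, and the 'logits' key used for metrics below, become configuration rather than literals.

```python
# Hypothetical sketch: `LossOutputFilter` is an illustrative name; the key
# is a constructor argument instead of being hard-coded in the method body.
class LossOutputFilter:
    def __init__(self, input_key: str = 'log_joint'):
        self.input_key = input_key

    def filter_output_for_loss(self, forward_out, target):
        if isinstance(forward_out, dict):
            return {'input': forward_out[self.input_key], 'target': target}
        return {'input': forward_out, 'target': target}
```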

Dict with 'preds' and 'target' for metric computation.
"""
if isinstance(forward_out, dict):
    return {'preds': forward_out['logits'], 'target': target}

same here: the 'logits' key is hard-coded too

# Distribution type groups.
_BINARY_DISTRIBUTIONS = {Bernoulli, RelaxedBernoulli}
_CATEGORICAL_DISTRIBUTIONS = {Categorical, OneHotCategorical, RelaxedOneHotCategorical}
_CONTINUOUS_DISTRIBUTIONS = {Normal, MultivariateNormal, Delta}

Is it safe to have such a limited pool of distributions here? What happens if a user asks for a different distribution? Would the code raise an error?
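
For reference, a sketch of the failure mode being asked about: a hypothetical lookup helper (not this PR's code) that raises explicitly instead of silently falling through when a distribution is outside the three groups.

```python
# Hypothetical helper: map a distribution class to its group, raising a
# clear error for anything outside the supported pool defined above.
def _distribution_group(dist_cls: type) -> str:
    if dist_cls in _BINARY_DISTRIBUTIONS:
        return 'binary'
    if dist_cls in _CATEGORICAL_DISTRIBUTIONS:
        return 'categorical'
    if dist_cls in _CONTINUOUS_DISTRIBUTIONS:
        return 'continuous'
    raise NotImplementedError(
        f"Variable elimination does not support {dist_cls.__name__}; "
        "supported families: binary, categorical, continuous."
    )
```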

target = kwargs['target']

# Flat 2D tensor → fallback to BCE (eval inference path)
if input.ndim == 2:

I am not sure about this. Why would evaluation need a different behavior? A loss should not be aware of the inference mode.
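
One way to address this (a sketch, assuming the trainer knows which mode it is in): select the criterion outside the loss instead of branching on input.ndim inside forward(), so each criterion stays mode-agnostic.

```python
import torch
import torch.nn as nn

# Sketch: the trainer owns the train/eval distinction.
# `compute_loss` is an illustrative helper, not this PR's code.
train_criterion = JointNLLLoss()         # consumes log-joint tensors
eval_criterion = nn.BCEWithLogitsLoss()  # consumes flat 2D logits

def compute_loss(output: torch.Tensor, target: torch.Tensor, training: bool) -> torch.Tensor:
    criterion = train_criterion if training else eval_criterion
    return criterion(input=output, target=target)
```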

return torch.tensor(0.0, device=input.device)


class JointNLLLoss(nn.Module):

Can we use a torch class for the NLL loss?

This sounds too hard-coded to me for the inferences that return log_joint.
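
For comparison, torch's built-in criterion: nn.NLLLoss consumes per-class log-probabilities and integer targets, which may or may not match what the inference returns here.

```python
import torch
import torch.nn as nn

# torch.nn.NLLLoss expects (batch, num_classes) log-probabilities and
# integer class targets of shape (batch,).
log_probs = torch.log_softmax(torch.randn(4, 3), dim=-1)
targets = torch.tensor([0, 2, 1, 0])
loss = nn.NLLLoss()(log_probs, targets)

# If the inference instead returns one log p(x, y) scalar per sample, the
# negative log-likelihood reduces to -log_joint.mean(), which a built-in
# criterion does not cover directly.
```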

        super().__init__()
        self._bce_fallback = nn.BCEWithLogitsLoss()

    def forward(self, **kwargs) -> torch.Tensor:

The loss signature should be explicit.

Right now, the keys of the dict returned by filter_output_for_loss have to match the forward signature.
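
A sketch of the explicit version, assuming the filter keeps producing {'input': ..., 'target': ...}:

```python
class JointNLLLoss(nn.Module):
    # Explicit signature: the keys of the filtered dict map directly onto
    # named parameters, so a mismatch fails loudly at call time.
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return -input.mean()  # placeholder body for illustration

# loss_fn(**filter_output_for_loss(forward_out, target)) then binds to the
# named parameters instead of passing silently through **kwargs.
```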
